from tqdm import tqdm
import torch
from torch import float16
import pandas as pd
import time

from torch.profiler import profile, record_function, ProfilerActivity
from .trainer_utils import full_count_params, count_params_test, count_params_train, max_dist_M_I


def _evaluate(NN, loader, criterion, metric, device):
    """
    Run NN over every batch of `loader` and accumulate loss and metric.

    INPUTS:
    NN : callable model, invoked as NN(inputs)
    loader : data loader yielding (inputs, labels); must expose __len__ and batch_size
    criterion : loss function called as criterion(outputs, labels)
    metric : metric function called as metric(outputs, labels)
    device : device the batches are moved to
    OUTPUTS:
    (loss, metric) : per-batch values summed after dividing each by
                     len(loader) * loader.batch_size
    NOTE: the caller is responsible for NN.eval() and torch.no_grad().
    """
    norm = len(loader) * loader.batch_size
    loss_total = 0.0
    metric_total = 0.0
    for inputs, labels in loader:
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = NN(inputs).detach().to(device)
        loss_total += float(criterion(outputs, labels).item()) / norm
        metric_total += float(metric(outputs, labels)) / norm
    return loss_total, metric_total


def _save_progress(running_data, path, save_name):
    """Write the results table to `path + save_name + '.csv'` and, best effort, to the drive folder."""
    running_data.to_csv(path + save_name + '.csv')
    try:
        running_data.to_csv(path + '/drive/MyDrive/nips2023_results/' + save_name + '.csv')
    except Exception:  # best-effort copy; do not swallow KeyboardInterrupt/SystemExit
        print('drive not found')


def _save_checkpoint(NN, path, save_name):
    """Pickle the whole model to `path + save_name + '.pt'` and, best effort, to the drive folder."""
    print('save')
    torch.save(NN, path + save_name + '.pt')
    try:
        torch.save(NN, path + '/drive/MyDrive/nips2023_results/' + save_name + '.pt')
    except Exception:  # best-effort copy; do not swallow KeyboardInterrupt/SystemExit
        print('drive not found')


def train(NN, optimizer, train_loader, validation_loader, test_loader, criterion, metric, epochs,
          metric_name='accuracy', device='cpu', count_bias=False, path=None,
          fine_tune=False, scheduler=None, save_weights=True, save_progress=False, save_name=''):
    """
    Train (or fine-tune) NN for `epochs` epochs, evaluating on the validation
    and optional test loader after every epoch.

    INPUTS:
    NN : neural network with custom layers and methods to optimize with dlra; this function
         uses NN.populate_gradients, NN.set_step, NN.lr_model and NN.cr (or NN.tau)
    optimizer : optimizer exposing step(closure=...), zero_grad(), baseline, tau and,
                for fine tuning, S_finetune_step()
    train/validation/test_loader : loader for datasets (test_loader may be None)
    criterion : loss function
    metric : metric function (NOTE: currently ignored, see below)
    epochs : number of epochs to train
    metric_name : name of the used metric (used for column names and log output)
    device : device the batches are moved to
    count_bias : flag variable if to count biases in params_count or not
    path : path string for where to save the results; nothing is saved when None
    fine_tune : if True run the fine-tuning loop instead of the main training loop
    scheduler : optional scheduler, stepped once per epoch with the training loss
    save_weights : if True save the model whenever the validation loss improves
                   on the best value seen so far (the epoch-0 value is the first reference)
    save_progress : if True write the results table every 5 epochs and at the last epoch
    save_name : file name (without extension) appended to `path` for saved files
    OUTPUTS:
    running_data : Pandas dataframe with the results of the run
                   (left empty by the fine-tuning loop, which records no rows)
    """

    running_data = pd.DataFrame(data=None, columns=['epoch', 'cr', 'max ||M-I||', 'learning_rate', 'train_loss',
                                                    'train_' + metric_name + '(%)', 'validation_loss',
                                                    'validation_' + metric_name + '(%)',
                                                    'test_' + metric_name + '(%)',
                                                    'ranks', '# effective parameters', 'cr_test (%)',
                                                    '# effective parameters train', 'cr_train (%)',
                                                    '# effective parameters train with grads',
                                                    'cr_train_grads (%)',
                                                    'timing batch forward'])

    total_params_full = full_count_params(NN, count_bias)
    total_params_full_grads = full_count_params(NN, count_bias, True)

    def accuracy(outputs, labels):
        # Count correct predictions with an integer sum: a float16 accumulator
        # cannot represent every integer above 2048, which made large-batch counts inexact.
        return torch.sum(torch.argmax(outputs.detach(), dim=1) == labels)

    # NOTE(review): the `metric` argument is overridden here, so the caller's metric is
    # never used. Kept as-is to preserve existing results; confirm whether this is intended.
    metric = accuracy

    if not fine_tune:  # Main Training Loop

        def trace_handler(p):  # Setup trace handler for profiling
            output = p.key_averages(group_by_stack_n=15).table(sort_by="self_cuda_time_total", row_limit=10)
            print(output)
            p.export_stacks("profiler_stacks_" + str(p.step_num) + ".txt", "self_cuda_time_total")
            print("trace_" + str(p.step_num) + ".json")
            p.export_chrome_trace("trace_" + str(p.step_num) + ".json")

        for epoch in tqdm(range(epochs)):

            print(f'epoch {epoch}---------------------------------------------')
            loss_hist = 0
            acc_hist = 0
            # Recomputed every epoch from the *training* loader: previously the batch size
            # left over from the validation/test pass was reused from epoch 1 onwards.
            train_norm = len(train_loader) * train_loader.batch_size
            # Per-batch timing is disabled in this branch, so 0.0 is what gets recorded.
            average_batch_time = 0.0

            NN.train()

            with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], with_stack=True,
                         schedule=torch.profiler.schedule(wait=1, warmup=1, active=1),
                         on_trace_ready=trace_handler) as p:
                for inputs, labels in train_loader:  # train loop
                    NN.zero_grad()
                    optimizer.zero_grad()
                    inputs, labels = inputs.to(device), labels.to(device)

                    def closure():
                        # Called by optimizer.step within this iteration, so it sees the current batch.
                        return NN.populate_gradients(inputs, labels, criterion, step='core')

                    loss, outputs = NN.populate_gradients(inputs, labels, criterion)
                    loss_hist += float(loss.item()) / train_norm
                    optimizer.step(closure=closure)

                    NN.set_step('test')

                    # Second forward pass after the update, used only for the training metric.
                    outputs = NN(inputs).detach().to(device)  # possibly expensive
                    acc_hist += float(metric(outputs, labels)) / train_norm

                    p.step()  # Evaluate time profiles
            NN.eval()

            with torch.no_grad():
                loss_hist_val, acc_hist_val = _evaluate(NN, validation_loader, criterion, metric, device)
                if test_loader is not None:
                    loss_hist_test, acc_hist_test = _evaluate(NN, test_loader, criterion, metric, device)
                else:
                    # Sentinel values meaning "no test set".
                    loss_hist_test = -1
                    acc_hist_test = -1

            print(
                f'epoch[{epoch}]: loss: {loss_hist:9.4f} | {metric_name}: {acc_hist:9.4f} | val loss: {loss_hist_val:9.4f} | val {metric_name}:{acc_hist_val:9.4f}')
            print('=' * 100)
            ranks = []
            for i, l in enumerate(NN.lr_model):
                if hasattr(l, 'dynamic_rank'):
                    print(f'rank layer {i} {l.dynamic_rank}')
                    ranks.append(l.dynamic_rank)
            print('\n')

            # Parameter counts relative to the full model counts computed before training.
            params_test = count_params_test(NN, count_bias)
            cr_test = round(params_test / total_params_full, 3)
            params_train = count_params_train(NN, count_bias)
            cr_train = round(params_train / total_params_full, 3)
            params_train_grads = count_params_train(NN, count_bias, True)
            cr_train_grads = round(params_train_grads / total_params_full_grads, 3)
            dist_M_I = max_dist_M_I(NN) if not optimizer.baseline else 'NA'

            print(f'cr: test {round(100 * (1 - cr_test), 4)} train_grads {round(100 * (1 - cr_train_grads), 4)}')

            compression_hyperparam = NN.cr if hasattr(NN, 'cr') else NN.tau
            epoch_data = [epoch, compression_hyperparam, dist_M_I, round(optimizer.tau, 5), round(loss_hist, 3),
                          round(acc_hist * 100, 4), round(loss_hist_val, 3),
                          round(acc_hist_val * 100, 4), round(acc_hist_test * 100, 4), ranks, params_test,
                          round(100 * (1 - cr_test), 4),
                          params_train, round(100 * (1 - cr_train), 4), params_train_grads,
                          round(100 * (1 - cr_train_grads), 4), average_batch_time]

            running_data.loc[epoch] = epoch_data

            if path is not None and save_progress and (epoch % 5 == 0 or epoch == epochs - 1):
                _save_progress(running_data, path, save_name)

            if scheduler is not None:
                scheduler.step(loss_hist)

            if epoch == 0:
                best_val_loss = loss_hist_val

            if save_weights and loss_hist_val < best_val_loss:
                best_val_loss = loss_hist_val
                # Without a path there is nowhere to write to; previously this raised TypeError.
                if path is not None:
                    _save_checkpoint(NN, path, save_name)

        return running_data

    else:

        for epoch in tqdm(range(epochs)):

            print(f'epoch {epoch}---------------------------------------------')
            loss_hist = 0
            acc_hist = 0
            train_norm = len(train_loader) * train_loader.batch_size
            k = len(train_loader)
            # Measured below but not stored anywhere in this branch.
            average_batch_time = 0.0

            NN.train()
            for inputs, labels in train_loader:  # train
                NN.zero_grad()
                optimizer.zero_grad()
                start = time.time()
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = NN(inputs).to(device)
                loss = criterion(outputs, labels)
                loss.backward()
                loss_hist += float(loss.item()) / train_norm
                acc_hist += float(metric(outputs.detach(), labels)) / train_norm
                optimizer.S_finetune_step()
                stop = time.time() - start
                average_batch_time += stop / k

            NN.eval()
            with torch.no_grad():
                loss_hist_val, acc_hist_val = _evaluate(NN, validation_loader, criterion, metric, device)
                if test_loader is not None:
                    loss_hist_test, acc_hist_test = _evaluate(NN, test_loader, criterion, metric, device)
                else:
                    # Sentinel values meaning "no test set".
                    loss_hist_test = -1
                    acc_hist_test = -1

            print(
                f'epoch[{epoch}]: loss: {loss_hist:9.4f} | {metric_name}: {acc_hist:9.4f} | val loss: {loss_hist_val:9.4f} | val {metric_name}:{acc_hist_val:9.4f}')
            print('=' * 100)
            ranks = []
            for i, l in enumerate(NN.lr_model):
                if hasattr(l, 'lr') and l.lr:
                    print(f'rank layer {i} {l.dynamic_rank}')
                    ranks.append(l.dynamic_rank)
            print('\n')

            # NOTE(review): no row is appended to running_data in this branch, so the CSV
            # written below and the returned frame contain only the header. Confirm intent.
            if path is not None and save_progress and (epoch % 5 == 0 or epoch == epochs - 1):
                _save_progress(running_data, path, save_name)

            if scheduler is not None:
                scheduler.step(loss_hist)

            if epoch == 0:
                best_val_loss = loss_hist_val

            if save_weights and loss_hist_val < best_val_loss:
                best_val_loss = loss_hist_val
                # Without a path there is nowhere to write to; previously this raised TypeError.
                if path is not None:
                    _save_checkpoint(NN, path, save_name)

        return running_data
